Decision Trees & Random Forests

Tree-Based Methods

In this section, we describe tree-based methods for regression and classification. These involve stratifying or segmenting the predictor space into a number of simple regions. In order to make a prediction for a given observation, we typically use the mean or the mode of the observations in the region to which it belongs. Since the set of splitting rules used to segment the predictor space can be summarized in a tree, these types of approaches are known as decision tree methods.

Tree-based methods are simple and useful for interpretation. However, they typically are not competitive with the best supervised learning approaches in terms of prediction accuracy. Hence, in this chapter we also introduce the random forests method that involves producing multiple trees which are then combined to yield a single consensus prediction. We will see that combining a large number of trees can often result in dramatic improvements in prediction accuracy, at the expense of some loss in interpretation.

Decision trees can be applied to both regression and classification problems.

Let’s look at the following example on “Baseball salary” data. How would you stratify this?

Salary is color-coded from low (blue, green) to high (yellow, red).

We use the Hitters data set to predict a baseball player’s Salary based on Years (the number of years that he has played in the major leagues) and Hits (the number of hits that he made in the previous year).

Overall, the tree stratifies or segments the players into three regions of predictor space: \(R_1=\{X \mid \text{Years} < 4.5\}\), \(R_2=\{X \mid \text{Years} \geq 4.5,\ \text{Hits} < 117.5\}\), and \(R_3=\{X \mid \text{Years} \geq 4.5,\ \text{Hits} \geq 117.5\}\).

Prediction via Stratification of the Feature Space

We now discuss the process of building a regression tree. Roughly speaking, there are two steps.

  1. We divide the predictor space - that is, the set of possible values for \(X_1, X_2, \ldots, X_p\) - into \(J\) distinct and non-overlapping regions, \(R_1, R_2, \ldots, R_J\), known as the terminal nodes or leaves of the tree. Decision trees are typically drawn upside down, in the sense that the leaves are at the bottom of the tree. The points along the tree where the predictor space is split are referred to as internal nodes. We refer to the segments of the tree that connect the nodes as branches.

  2. For every observation that falls into the region \(R_k\), we make the same prediction, which is simply the mean of the response values for the observations in \(R_k\).

How do we construct the regions \(R_1, R_2, \ldots, R_J\)? In theory, the regions could have any shape. However, we choose to divide the predictor space into high-dimensional rectangles, or boxes, for simplicity and for ease of interpretation of the resulting predictive model. The goal is to find boxes \(R_1, R_2, \ldots, R_J\) that minimize the RSS, given by \[ \sum_{k=1}^J \sum_{i \in R_k}\left(y_i-\hat{y}_{R_k}\right)^2 \] where \(\hat{y}_{R_k}\) is the mean response for the observations within the \(k\)th box.

Unfortunately, it is computationally infeasible to consider every possible partition of the feature space into \(J\) boxes. For this reason, we take a top-down, greedy approach known as recursive binary splitting. The approach is top-down because it begins at the top of the tree (at which point all observations belong to a single region) and then successively splits the predictor space; each split is indicated via two new branches further down on the tree. It is greedy because at each step of the tree-building process, the best split is made at that particular step, rather than looking ahead and picking a split that will lead to a better tree in some future step.

In order to perform recursive binary splitting, we first select the predictor \(X_k\) and the cut-point \(s\) such that splitting the predictor space into the regions \(\left\{X \mid X_k<s\right\}\) and \(\left\{X \mid X_k \geq s\right\}\) leads to the greatest possible reduction in RSS. That is, we consider all predictors \(X_1, X_2, \ldots, X_p\) and all possible values of the cut-point \(s\) for each of the predictors, and then choose the predictor and cut-point such that the resulting tree has the lowest RSS.

In greater detail, for any \(k\) and \(s\), we define the pair of half-planes \[ R_1(k, s)=\left\{X \mid X_k<s\right\} \] and \[ R_2(k, s)=\left\{X \mid X_k \geq s\right\} \] and we seek the values of \(k\) and \(s\) that minimize \[ \sum_{i:\, x_i \in R_1(k, s)}\left(y_i-\hat{y}_{R_1}\right)^2+\sum_{i:\, x_i \in R_2(k, s)}\left(y_i-\hat{y}_{R_2}\right)^2 \] where \(\hat{y}_{R_1}\) is the mean response for the observations in \(R_1(k, s)\), and \(\hat{y}_{R_2}\) is the mean response for the observations in \(R_2(k, s)\). Finding the values of \(k\) and \(s\) that minimize this equation can be done quite quickly, especially when the number of features \(p\) is not too large.

Next, we repeat the process, looking for the best predictor and best cut-point in order to split the data further so as to minimize the RSS within each of the resulting regions. However, this time, instead of splitting the entire predictor space, we split one of the two previously identified regions. We now have three regions. Again, we look to split one of these three regions further, so as to minimize the RSS. The process continues until a stopping criterion is reached; for instance, we may continue until no region contains more than five observations. Once the regions \(R_1, R_2, \ldots, R_J\) have been created, we predict the response for a given test observation using the mean of the training observations in the region to which that observation belongs.
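The greedy search described above can be sketched in a few lines of base R. The helper best_split() below is illustrative (not part of any package): it scans every predictor and every candidate cut-point and returns the pair that minimizes the RSS.

```r
# One greedy step of recursive binary splitting: for each predictor,
# try the midpoints between consecutive unique values as cut-points,
# and keep the (variable, cut-point) pair with the smallest RSS.
best_split <- function(X, y) {
  best <- list(rss = Inf, var = NA, cut = NA)
  for (j in seq_len(ncol(X))) {
    xs <- sort(unique(X[[j]]))
    cuts <- (head(xs, -1) + tail(xs, -1)) / 2  # candidate values of s
    for (s in cuts) {
      left  <- y[X[[j]] <  s]
      right <- y[X[[j]] >= s]
      rss <- sum((left  - mean(left ))^2) +
             sum((right - mean(right))^2)
      if (rss < best$rss) best <- list(rss = rss, var = names(X)[j], cut = s)
    }
  }
  best
}

# Example usage on the Boston data (requires MASS):
# library(MASS)
# best_split(Boston[, c("rm", "lstat")], Boston$medv)
```

A real implementation would apply this function recursively to each resulting region until a stopping criterion is met.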

Top Left: A partition of two-dimensional feature space that could not result from recursive binary splitting. Top Right: The output of recursive binary splitting on a two-dimensional example. Bottom Left: A tree corresponding to the partition in the top right panel. Bottom Right: A perspective plot of the prediction surface corresponding to that tree.

Tree Pruning

The process described above is likely to overfit the data because the resulting tree might be too complex. A smaller tree with fewer splits (that is, fewer regions \(R_1, R_2, \ldots, R_J\) ) might lead to lower variance and better interpretation at the cost of a little bias. One possible alternative to the process described above is to build the tree only so long as the decrease in the RSS due to each split exceeds some (high) threshold. This strategy will result in smaller trees, but is too short-sighted since a seemingly worthless split early on in the tree might be followed by a very good split - that is, a split that leads to a large reduction in RSS later on. Therefore, a better strategy is to grow a very large tree \(T_0\), and then prune it back in order to obtain a subtree. How do we determine the best way to prune the tree? Intuitively, our goal is to select a subtree that leads to the lowest test error rate. Given a subtree, we can estimate its test error using cross-validation or the validation set approach. However, estimating the cross-validation error for every possible subtree would be too cumbersome, since there is an extremely large number of possible subtrees. Instead, we need a way to select a small set of subtrees for consideration.

Cost complexity pruning - also known as weakest link pruning - gives us a way to do just this. Rather than considering every possible subtree, we consider a sequence of trees indexed by a nonnegative tuning parameter \(\alpha\). For each value of \(\alpha\) there corresponds a subtree \(T \subset T_0\) such that

\[ \sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m}\left(y_i-\hat{y}_{R_m}\right)^2+\alpha|T| \]

is as small as possible. Here \(|T|\) indicates the number of terminal nodes of the tree \(T\), \(R_m\) is the rectangle (i.e., the subset of predictor space) corresponding to the \(m\)th terminal node, and \(\hat{y}_{R_m}\) is the predicted response associated with \(R_m\), that is, the mean of the training observations in \(R_m\). The tuning parameter \(\alpha\) controls a trade-off between the subtree's complexity and its fit to the training data. When \(\alpha=0\), the subtree \(T\) will simply equal \(T_0\). However, as \(\alpha\) increases, there is a price to pay for having a tree with many terminal nodes, and so the above quantity will tend to be minimized for a smaller subtree. This equation is reminiscent of the lasso. It turns out that as we increase \(\alpha\) from zero, branches get pruned from the tree in a nested and predictable fashion, so obtaining the whole sequence of subtrees as a function of \(\alpha\) is easy. We can select a value of \(\alpha\) using a validation set or cross-validation. We then return to the full data set and obtain the subtree corresponding to that \(\alpha\). This process is summarized in the following algorithm:

Building a Regression Tree

  1. Use recursive binary splitting to grow a large tree on the training data, stopping only when each terminal node has fewer than some minimum number of observations.

  2. Apply cost complexity pruning to the large tree in order to obtain a sequence of best subtrees, as a function of \(\alpha\).

  3. Use K-fold cross-validation to choose \(\alpha\). That is, divide the training observations into \(K\) folds. For each \(k=1, \ldots, K\):

     a. Repeat Steps 1 and 2 on all but the \(k\)th fold of the training data.

     b. Evaluate the mean squared prediction error on the data in the left-out \(k\)th fold, as a function of \(\alpha\).

     Then average the results for each value of \(\alpha\), and pick \(\alpha\) to minimize the average error.

  4. Return the subtree from Step 2 that corresponds to the chosen value of \(\alpha\).
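With the tree package, Steps 1 and 2 of this algorithm correspond to tree() followed by prune.tree(); a minimal sketch on the Boston data (assumes MASS and tree are installed):

```r
library(MASS)  # for the Boston data
library(tree)

big.tree  <- tree(medv ~ ., Boston)  # Step 1: grow a large tree
seq.trees <- prune.tree(big.tree)    # Step 2: nested sequence of subtrees
seq.trees$size                       # subtree sizes |T| along the sequence
seq.trees$k                          # the corresponding values of alpha

one.tree <- prune.tree(big.tree, k = 10)  # subtree for one particular alpha
```

Step 3 is what cv.tree() automates, as shown later in the Boston example.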

Regression tree analysis for the Hitters data. The unpruned tree that results from top-down greedy splitting on the training data is shown.
Regression tree analysis for the Hitters data. The training, cross-validation, and test MSE are shown as a function of the number of terminal nodes in the pruned tree. Standard error bands are displayed. The minimum cross-validation error occurs at a tree size of three.

Example: Boston Housing Data: Regression Tree

Here we fit a regression tree to the Boston data set. First, we create a training set, and fit the tree to the training data.

library(MASS)
library(tree)

set.seed(1)
train = sample(1:nrow(Boston), nrow(Boston)/2)

Fitting a model:

tree.boston=tree(medv~.,Boston,subset=train)
summary(tree.boston)
## 
## Regression tree:
## tree(formula = medv ~ ., data = Boston, subset = train)
## Variables actually used in tree construction:
## [1] "rm"    "lstat" "crim"  "age"  
## Number of terminal nodes:  7 
## Residual mean deviance:  10.38 = 2555 / 246 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -10.1800  -1.7770  -0.1775   0.0000   1.9230  16.5800

Notice that the output of summary() indicates that only four of the variables have been used in constructing the tree. In the context of a regression tree, the deviance is simply the sum of squared errors for the tree.

We now plot the tree.

plot(tree.boston)
text(tree.boston,pretty=0)

The variable lstat measures the percentage of individuals with lower socioeconomic status. The tree indicates that lower values of lstat correspond to more expensive houses. The tree predicts a median house price of $46,400 for larger homes in suburbs in which residents have high socioeconomic status.

Additionally, you can use the rpart package to fit a tree model. With this package you can get nicer visualizations.

library(rpart)
library(rattle)

tree.boston2 = rpart(medv~.,Boston,subset = train)

fancyRpartPlot(tree.boston2) # you need 'rattle' to use this

library(sparkline)
library(visNetwork) # visTree() needs a tree fitted with rpart
visTree(tree.boston2) # another visualization

You can use the above packages if you need better visualizations, but we will continue to use the tree package.

Now we use the cv.tree() function to see whether pruning the tree will improve performance.

cv.boston=cv.tree(tree.boston)
plot(cv.boston$size,cv.boston$dev,type='b')

In this case, the most complex tree is selected by cross-validation. However, if we wish to prune the tree, we could do so as follows, using the prune.tree() function:

prune.boston=prune.tree(tree.boston,best=5)
plot(prune.boston)
text(prune.boston,pretty=0)

In keeping with the cross-validation results, we use the unpruned tree to make predictions on the test set.

yhat=predict(tree.boston,newdata=Boston[-train,])
boston.test=Boston[-train,"medv"]#testing y values

mean((yhat-boston.test)^2) #MSE testing
## [1] 35.28688

In other words, the test set MSE associated with the regression tree is about 35.29. The square root of the MSE is therefore around 5.94, indicating that this model leads to test predictions that are within around $5,940 of the true median home value for the suburb.
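Equivalently, continuing the session above, the test RMSE can be computed directly:

```r
# square root of the test MSE, in the same units as medv (~5.94 here)
sqrt(mean((yhat - boston.test)^2))
```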

Advantages and Disadvantages of Trees

  • Decision trees for regression and classification have a number of advantages over the more classical approaches seen in early lessons.

  • Trees are very easy to explain to people. In fact, they are even easier to explain than linear regression!

  • Some people believe that decision trees more closely mirror human decision-making than do the regression and classification approaches seen in previous chapters.

  • Trees can be displayed graphically, and are easily interpreted even by a non-expert (especially if they are small).

  • Trees can easily handle qualitative predictors without the need to create dummy variables.

  • Unfortunately, trees generally do not have the same level of predictive accuracy as some of the other regression and classification approaches seen in this book.

  • Additionally, trees can be very non-robust. In other words, a small change in the data can cause a large change in the final estimated tree.

Random Forests

Random forests provide an improvement over bagged trees by way of a small tweak that decorrelates the trees. As in bagging, we build a number of decision trees on bootstrapped samples. But when building these decision trees, each time a split in a tree is considered, a random sample of \(m\) predictors is chosen as split candidates from the full set of \(p\) predictors. The split is allowed to use only one of those \(m\) predictors. A fresh sample of \(m\) predictors is taken at each split, and typically we choose \(m \approx \sqrt{p}\). In other words, in building a random forest, at each split in the tree, the algorithm is not even allowed to consider a majority of the available predictors. This may sound crazy, but it has a clever rationale.

Suppose that there is one very strong predictor in the data set, along with a number of other moderately strong predictors. Then in the collection of bagged trees, most or all of the trees will use this strong predictor in the top split. Consequently, all of the bagged trees will look quite similar to each other. Hence the predictions from the bagged trees will be highly correlated. Unfortunately, averaging many highly correlated quantities does not lead to as large a reduction in variance as averaging many uncorrelated quantities. In particular, this means that bagging will not lead to a substantial reduction in variance over a single tree in this setting. Random forests overcome this problem by forcing each split to consider only a subset of the predictors. Therefore, on average \((p-m) / p\) of the splits will not even consider the strong predictor, and so other predictors will have more of a chance. We can think of this process as decorrelating the trees, thereby making the average of the resulting trees less variable and hence more reliable.
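This effect can be quantified: the average of \(B\) identically distributed quantities, each with variance \(\sigma^2\) and pairwise correlation \(\rho\), has variance \(\rho\sigma^2 + (1-\rho)\sigma^2/B\). A quick numerical sketch (illustrative, not from the text):

```r
# variance of the average of B correlated predictions,
# each with variance sigma2 and pairwise correlation rho
avg_var <- function(rho, sigma2 = 1, B = 100) {
  rho * sigma2 + (1 - rho) * sigma2 / B
}

avg_var(0.9) # highly correlated trees: ~0.901, averaging barely helps
avg_var(0.1) # decorrelated trees: ~0.109, close to sigma2 / B
```

Increasing \(B\) cannot remove the \(\rho\sigma^2\) term; only reducing the correlation between trees can, which is exactly what the random predictor subsets accomplish.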

Example: Boston Housing Data (cont.)

Growing a random forest proceeds in exactly the same way, except that we use a smaller value of the mtry argument. By default, randomForest() uses \(p/3\) variables when building a random forest of regression trees, and \(\sqrt{p}\) variables when building a random forest of classification trees. Here we use mtry = 6.

library(randomForest)

set.seed(1)
rf.boston=randomForest(medv~.,data=Boston,subset=train,mtry=6,importance=TRUE)
yhat.rf = predict(rf.boston,newdata=Boston[-train,])
mean((yhat.rf-boston.test)^2) #test set MSE
## [1] 19.62021
plot(rf.boston)

Variable importance

importance(rf.boston)
##           %IncMSE IncNodePurity
## crim    16.697017    1076.08786
## zn       3.625784      88.35342
## indus    4.968621     609.53356
## chas     1.061432      52.21793
## nox     13.518179     709.87339
## rm      32.343305    7857.65451
## age     13.272498     612.21424
## dis      9.032477     714.94674
## rad      2.878434      95.80598
## tax      9.118801     364.92479
## ptratio  8.467062     823.93341
## black    7.579482     275.62272
## lstat   27.129817    6027.63740

Two measures of variable importance are reported. The first (%IncMSE) is the mean decrease in prediction accuracy when the values of that variable are permuted, thus breaking its relationship with the response. The second (IncNodePurity) is a measure of the total decrease in node impurity that results from splits over that variable, averaged over all trees. In the case of regression trees, node impurity is measured by the training RSS, and for classification trees by the deviance. Plots of these importance measures can be produced using the varImpPlot() function.

varImpPlot(rf.boston)

Example: Carseats data set: Classification tree

For this example let’s use the Carseats data set.

library(tree)
library(ISLR)
library(psych)

attach(Carseats) 
summary(Carseats)
##      Sales          CompPrice       Income        Advertising    
##  Min.   : 0.000   Min.   : 77   Min.   : 21.00   Min.   : 0.000  
##  1st Qu.: 5.390   1st Qu.:115   1st Qu.: 42.75   1st Qu.: 0.000  
##  Median : 7.490   Median :125   Median : 69.00   Median : 5.000  
##  Mean   : 7.496   Mean   :125   Mean   : 68.66   Mean   : 6.635  
##  3rd Qu.: 9.320   3rd Qu.:135   3rd Qu.: 91.00   3rd Qu.:12.000  
##  Max.   :16.270   Max.   :175   Max.   :120.00   Max.   :29.000  
##    Population        Price        ShelveLoc        Age          Education   
##  Min.   : 10.0   Min.   : 24.0   Bad   : 96   Min.   :25.00   Min.   :10.0  
##  1st Qu.:139.0   1st Qu.:100.0   Good  : 85   1st Qu.:39.75   1st Qu.:12.0  
##  Median :272.0   Median :117.0   Medium:219   Median :54.50   Median :14.0  
##  Mean   :264.8   Mean   :115.8                Mean   :53.32   Mean   :13.9  
##  3rd Qu.:398.5   3rd Qu.:131.0                3rd Qu.:66.00   3rd Qu.:16.0  
##  Max.   :509.0   Max.   :191.0                Max.   :80.00   Max.   :18.0  
##  Urban       US     
##  No :118   No :142  
##  Yes:282   Yes:258  
##                     
##                     
##                     
## 
pairs.panels(Carseats)

Let’s create a binary indicator variable, High, from the continuous variable Sales.

(High <- factor(ifelse(Sales <= 8, "No", "Yes")))
##   [1] Yes Yes Yes No  No  Yes No  Yes No  No  Yes Yes No  Yes Yes Yes No  Yes
##  [19] Yes Yes No  Yes No  No  Yes Yes Yes No  No  No  Yes Yes No  Yes No  Yes
##  [37] Yes No  No  No  No  No  Yes No  No  No  Yes No  No  Yes No  No  No  No 
##  [55] No  No  Yes No  No  No  Yes No  No  Yes No  No  Yes Yes Yes No  Yes No 
##  [73] No  Yes No  Yes Yes No  No  Yes Yes No  Yes No  No  Yes Yes Yes No  No 
##  [91] No  No  No  Yes Yes No  Yes No  Yes No  No  No  No  No  No  No  No  Yes
## [109] No  Yes Yes No  No  No  Yes Yes No  Yes No  No  No  Yes No  Yes Yes Yes
## [127] Yes No  No  No  Yes No  Yes No  No  No  No  No  Yes Yes No  No  No  No 
## [145] Yes Yes No  Yes No  Yes Yes Yes No  No  No  No  No  Yes Yes Yes No  No 
## [163] No  No  Yes No  No  No  No  Yes Yes Yes Yes No  No  No  No  Yes Yes No 
## [181] No  No  No  No  Yes Yes Yes No  Yes Yes Yes No  No  Yes No  No  No  No 
## [199] No  No  No  No  No  No  Yes No  No  Yes No  No  No  Yes Yes Yes No  No 
## [217] No  No  Yes Yes Yes No  No  No  No  No  No  Yes No  Yes No  Yes Yes Yes
## [235] Yes No  Yes Yes No  No  Yes Yes No  No  Yes Yes No  No  No  No  Yes No 
## [253] Yes No  Yes No  No  Yes No  No  No  No  No  No  No  No  Yes No  No  No 
## [271] Yes No  Yes Yes No  No  No  No  No  No  No  Yes No  No  No  No  No  No 
## [289] No  Yes Yes No  Yes Yes Yes No  Yes No  Yes Yes Yes No  No  Yes Yes Yes
## [307] No  No  Yes Yes Yes No  No  Yes No  No  Yes No  Yes No  No  No  Yes Yes
## [325] No  Yes No  No  No  Yes No  Yes No  No  No  No  No  Yes No  Yes No  No 
## [343] No  No  Yes No  Yes No  Yes Yes Yes Yes Yes Yes No  No  No  Yes No  No 
## [361] Yes Yes No  Yes Yes No  No  Yes Yes Yes No  Yes No  No  Yes No  Yes No 
## [379] No  No  Yes No  No  Yes Yes No  No  Yes Yes Yes No  No  No  No  No  Yes
## [397] No  No  No  Yes
## Levels: No Yes
Carseats=data.frame(Carseats,High)

Fit a tree to predict the High variable using all variables except Sales:

tree.carseats=tree(High~.-Sales,Carseats)
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats)
## Variables actually used in tree construction:
## [1] "ShelveLoc"   "Price"       "Income"      "CompPrice"   "Population" 
## [6] "Advertising" "Age"         "US"         
## Number of terminal nodes:  27 
## Residual mean deviance:  0.4575 = 170.7 / 373 
## Misclassification error rate: 0.09 = 36 / 400

The training error rate is 9%. See page 325 in ISLR for the residual mean deviance (RMD) formula; the smaller the RMD, the better.

Plot the tree

plot(tree.carseats)
text(tree.carseats,pretty=0)

tree.carseats
## node), split, n, deviance, yval, (yprob)
##       * denotes terminal node
## 
##   1) root 400 541.500 No ( 0.59000 0.41000 )  
##     2) ShelveLoc: Bad,Medium 315 390.600 No ( 0.68889 0.31111 )  
##       4) Price < 92.5 46  56.530 Yes ( 0.30435 0.69565 )  
##         8) Income < 57 10  12.220 No ( 0.70000 0.30000 )  
##          16) CompPrice < 110.5 5   0.000 No ( 1.00000 0.00000 ) *
##          17) CompPrice > 110.5 5   6.730 Yes ( 0.40000 0.60000 ) *
##         9) Income > 57 36  35.470 Yes ( 0.19444 0.80556 )  
##          18) Population < 207.5 16  21.170 Yes ( 0.37500 0.62500 ) *
##          19) Population > 207.5 20   7.941 Yes ( 0.05000 0.95000 ) *
##       5) Price > 92.5 269 299.800 No ( 0.75465 0.24535 )  
##        10) Advertising < 13.5 224 213.200 No ( 0.81696 0.18304 )  
##          20) CompPrice < 124.5 96  44.890 No ( 0.93750 0.06250 )  
##            40) Price < 106.5 38  33.150 No ( 0.84211 0.15789 )  
##              80) Population < 177 12  16.300 No ( 0.58333 0.41667 )  
##               160) Income < 60.5 6   0.000 No ( 1.00000 0.00000 ) *
##               161) Income > 60.5 6   5.407 Yes ( 0.16667 0.83333 ) *
##              81) Population > 177 26   8.477 No ( 0.96154 0.03846 ) *
##            41) Price > 106.5 58   0.000 No ( 1.00000 0.00000 ) *
##          21) CompPrice > 124.5 128 150.200 No ( 0.72656 0.27344 )  
##            42) Price < 122.5 51  70.680 Yes ( 0.49020 0.50980 )  
##              84) ShelveLoc: Bad 11   6.702 No ( 0.90909 0.09091 ) *
##              85) ShelveLoc: Medium 40  52.930 Yes ( 0.37500 0.62500 )  
##               170) Price < 109.5 16   7.481 Yes ( 0.06250 0.93750 ) *
##               171) Price > 109.5 24  32.600 No ( 0.58333 0.41667 )  
##                 342) Age < 49.5 13  16.050 Yes ( 0.30769 0.69231 ) *
##                 343) Age > 49.5 11   6.702 No ( 0.90909 0.09091 ) *
##            43) Price > 122.5 77  55.540 No ( 0.88312 0.11688 )  
##              86) CompPrice < 147.5 58  17.400 No ( 0.96552 0.03448 ) *
##              87) CompPrice > 147.5 19  25.010 No ( 0.63158 0.36842 )  
##               174) Price < 147 12  16.300 Yes ( 0.41667 0.58333 )  
##                 348) CompPrice < 152.5 7   5.742 Yes ( 0.14286 0.85714 ) *
##                 349) CompPrice > 152.5 5   5.004 No ( 0.80000 0.20000 ) *
##               175) Price > 147 7   0.000 No ( 1.00000 0.00000 ) *
##        11) Advertising > 13.5 45  61.830 Yes ( 0.44444 0.55556 )  
##          22) Age < 54.5 25  25.020 Yes ( 0.20000 0.80000 )  
##            44) CompPrice < 130.5 14  18.250 Yes ( 0.35714 0.64286 )  
##              88) Income < 100 9  12.370 No ( 0.55556 0.44444 ) *
##              89) Income > 100 5   0.000 Yes ( 0.00000 1.00000 ) *
##            45) CompPrice > 130.5 11   0.000 Yes ( 0.00000 1.00000 ) *
##          23) Age > 54.5 20  22.490 No ( 0.75000 0.25000 )  
##            46) CompPrice < 122.5 10   0.000 No ( 1.00000 0.00000 ) *
##            47) CompPrice > 122.5 10  13.860 No ( 0.50000 0.50000 )  
##              94) Price < 125 5   0.000 Yes ( 0.00000 1.00000 ) *
##              95) Price > 125 5   0.000 No ( 1.00000 0.00000 ) *
##     3) ShelveLoc: Good 85  90.330 Yes ( 0.22353 0.77647 )  
##       6) Price < 135 68  49.260 Yes ( 0.11765 0.88235 )  
##        12) US: No 17  22.070 Yes ( 0.35294 0.64706 )  
##          24) Price < 109 8   0.000 Yes ( 0.00000 1.00000 ) *
##          25) Price > 109 9  11.460 No ( 0.66667 0.33333 ) *
##        13) US: Yes 51  16.880 Yes ( 0.03922 0.96078 ) *
##       7) Price > 135 17  22.070 No ( 0.64706 0.35294 )  
##        14) Income < 46 6   0.000 No ( 1.00000 0.00000 ) *
##        15) Income > 46 11  15.160 Yes ( 0.45455 0.54545 ) *

By using the above code we can see the fit and prediction at each branch.

If we just type the name of the tree object, R prints output corresponding to each branch of the tree. R displays the split criterion (e.g. Price<92.5), the number of observations in that branch, the deviance, the overall prediction for the branch (Yes or No), and the fraction of observations in that branch that take on values of Yes and No. Branches that lead to terminal nodes are indicated using asterisks.

Now, split the data into test and training sets to properly assess performance:

set.seed(2)
 #indices for 200 observations randomly selected as training samples
train=sample(1:nrow(Carseats), 200)
Carseats.test=Carseats[-train,] #create test data set which contains other observations
High.test=High[-train] #test data sets of only response variables

Fit the tree on the training data by specifying subset=train:

tree.carseats=tree(High~.-Sales,Carseats,subset=train) #fit on training data
summary(tree.carseats)
## 
## Classification tree:
## tree(formula = High ~ . - Sales, data = Carseats, subset = train)
## Variables actually used in tree construction:
## [1] "Price"       "Population"  "ShelveLoc"   "Age"         "Education"  
## [6] "CompPrice"   "Advertising" "Income"      "US"         
## Number of terminal nodes:  21 
## Residual mean deviance:  0.5543 = 99.22 / 179 
## Misclassification error rate: 0.115 = 23 / 200

Now let’s predict:

#predict on test data
(tree.pred=predict(tree.carseats,Carseats.test,type="class")) 
##   [1] Yes No  No  Yes No  No  Yes Yes Yes No  No  No  No  Yes Yes No  Yes No 
##  [19] No  No  No  No  No  No  No  No  No  No  Yes No  No  No  No  No  Yes No 
##  [37] Yes Yes No  Yes Yes Yes No  No  No  No  No  Yes No  No  No  No  Yes No 
##  [55] No  No  No  No  No  No  Yes No  No  No  No  No  Yes Yes No  Yes Yes No 
##  [73] Yes Yes Yes No  No  No  No  Yes Yes Yes No  No  No  Yes Yes No  Yes No 
##  [91] No  No  No  No  No  Yes No  No  No  No  No  Yes Yes No  No  Yes Yes Yes
## [109] Yes No  Yes No  No  Yes No  No  Yes No  No  Yes No  No  No  No  No  No 
## [127] No  No  No  No  Yes No  Yes Yes No  Yes No  No  Yes No  Yes Yes No  No 
## [145] No  No  Yes No  No  No  No  No  Yes No  Yes No  Yes No  No  Yes No  No 
## [163] No  No  No  Yes No  No  Yes No  Yes No  No  No  Yes No  Yes Yes No  No 
## [181] No  No  No  No  No  No  Yes No  No  No  Yes No  No  No  No  No  No  No 
## [199] Yes No 
## Levels: No Yes

Notice that, unlike with the regression tree, we cannot calculate an MSE value for the classification tree, because the predictions are class labels rather than numerical values.

Therefore, we need another way to evaluate the model.

Calculating error terms for classification problems

Consider the following example, in which a classifier predicts whether each of 165 patients has a disease (counts reconstructed from the rates computed below):

                 Predicted: No   Predicted: Yes   Total
    Actual: No        TN = 50          FP = 10       60
    Actual: Yes       FN = 5           TP = 100     105
    Total                  55              110      165

This table is called a Confusion Matrix:

Define the terms:

  • True Positives (TP): These are cases in which we predicted yes (they have the disease), and they do have the disease.

  • True Negatives (TN): We predicted no, and they don’t have the disease.

  • False Positives (FP): We predicted yes, but they don’t actually have the disease. (Also known as a “Type I error.”)

  • False Negatives (FN): We predicted no, but they actually do have the disease. (Also known as a “Type II error.”)

This is a list of rates that are often computed from a confusion matrix for a binary classifier:

  • Accuracy: Overall, how often is the classifier correct?

    • (TP+TN)/total = (100+50)/165 = 0.91
  • Misclassification Rate: Overall, how often is it wrong?

    • (FP+FN)/total = (10+5)/165 = 0.09

    • equivalent to ‘1 - Accuracy’

    • also known as “Error Rate”

  • True Positive Rate: When it’s actually yes, how often does it predict yes?

    • TP/actual yes = 100/105 = 0.95

    • also known as “Sensitivity” or “Recall”

  • False Positive Rate: When it’s actually no, how often does it predict yes?

    • FP/actual no = 10/60 = 0.17
  • True Negative Rate: When it’s actually no, how often does it predict no?

    • TN/actual no = 50/60 = 0.83

    • equivalent to ‘1 - False Positive Rate’

    • also known as “Specificity”

  • Precision: When it predicts yes, how often is it correct?

    • TP/predicted yes = 100/110 = 0.91
  • Prevalence: How often does the yes condition actually occur in our sample?

    • actual yes/total = 105/165 = 0.64
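All of these rates can be computed directly from the four counts; a minimal sketch in R using the counts from the example above:

```r
# Confusion-matrix rates from the example counts
TP <- 100; TN <- 50; FP <- 10; FN <- 5
total <- TP + TN + FP + FN         # 165

accuracy    <- (TP + TN) / total   # 0.91
error.rate  <- (FP + FN) / total   # 0.09, = 1 - accuracy
sensitivity <- TP / (TP + FN)      # 0.95, true positive rate / recall
fpr         <- FP / (FP + TN)      # 0.17, false positive rate
specificity <- TN / (TN + FP)      # 0.83, = 1 - fpr
precision   <- TP / (TP + FP)      # 0.91
prevalence  <- (TP + FN) / total   # 0.64
```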

Example: Carseats data set (cont.)

Now let’s calculate the confusion matrix for this.

table(tree.pred,High.test)#look at predicted vs. actual values on test data.
##          High.test
## tree.pred  No Yes
##       No  104  33
##       Yes  13  50

(104+50)/200 = 0.77, i.e., the tree correctly predicts 77% of the test observations.

Or you can use the caret package to calculate the confusion matrix.

library(caret)
confusionMatrix(tree.pred,High.test)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  104  33
##        Yes  13  50
##                                           
##                Accuracy : 0.77            
##                  95% CI : (0.7054, 0.8264)
##     No Information Rate : 0.585           
##     P-Value [Acc > NIR] : 2.938e-08       
##                                           
##                   Kappa : 0.5091          
##                                           
##  Mcnemar's Test P-Value : 0.005088        
##                                           
##             Sensitivity : 0.8889          
##             Specificity : 0.6024          
##          Pos Pred Value : 0.7591          
##          Neg Pred Value : 0.7937          
##              Prevalence : 0.5850          
##          Detection Rate : 0.5200          
##    Detection Prevalence : 0.6850          
##       Balanced Accuracy : 0.7456          
##                                           
##        'Positive' Class : No              
## 

Now let’s prune the tree.

set.seed(9)
cv.carseats=cv.tree(tree.carseats,FUN=prune.misclass)
names(cv.carseats)
## [1] "size"   "dev"    "k"      "method"

The function cv.tree() performs cross-validation in order to determine the optimal level of tree complexity; cost-complexity pruning is used in order to select a sequence of trees for consideration. We use the argument FUN=prune.misclass to indicate that we want the classification error rate, rather than the default (deviance), to guide the cross-validation and pruning process. The cv.tree() function reports the number of terminal nodes of each tree considered (size), the corresponding cross-validation error (dev), and the value of the cost-complexity parameter used (k). The default is 10-fold cross-validation.

cv.carseats
## $size
## [1] 21 19 14  9  8  5  3  2  1
## 
## $dev
## [1] 72 73 72 72 77 77 78 83 84
## 
## $k
## [1] -Inf  0.0  1.0  1.4  2.0  3.0  4.0  9.0 18.0
## 
## $method
## [1] "misclass"
## 
## attr(,"class")
## [1] "prune"         "tree.sequence"

Note that, despite the name, dev corresponds to the cross-validation error rate in this instance. The lowest cross-validation error rate, 72 errors, is attained by the trees with 21, 14, and 9 terminal nodes; we select the smallest of these, the tree with 9 terminal nodes.

We plot the error rate as a function of both size and k.

par(mfrow=c(1,2))
plot(cv.carseats$size,cv.carseats$dev,type="b")
plot(cv.carseats$k,cv.carseats$dev,type="b")
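Rather than reading the optimum off the plots, the best size can also be selected programmatically. A minimal sketch, using the size and dev vectors stored in cv.carseats and choosing the smallest tree among those attaining the minimum cross-validation error:

```r
# Smallest tree size attaining the minimum cross-validation error.
best.size <- min(cv.carseats$size[cv.carseats$dev == min(cv.carseats$dev)])
best.size
## [1] 9
```

This value can then be passed to prune.misclass() as the best argument.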

We now apply the prune.misclass() function in order to prune the tree to obtain the nine-node tree.

prune.carseats=prune.misclass(tree.carseats,best=9)
plot(prune.carseats)
text(prune.carseats,pretty=0)

How well does this pruned tree perform on the test data set? Once again, we apply the predict() function.

tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  97  25
##       Yes 20  58
confusionMatrix(tree.pred,High.test)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction No Yes
##        No  97  25
##        Yes 20  58
##                                           
##                Accuracy : 0.775           
##                  95% CI : (0.7108, 0.8309)
##     No Information Rate : 0.585           
##     P-Value [Acc > NIR] : 1.206e-08       
##                                           
##                   Kappa : 0.5325          
##                                           
##  Mcnemar's Test P-Value : 0.551           
##                                           
##             Sensitivity : 0.8291          
##             Specificity : 0.6988          
##          Pos Pred Value : 0.7951          
##          Neg Pred Value : 0.7436          
##              Prevalence : 0.5850          
##          Detection Rate : 0.4850          
##    Detection Prevalence : 0.6100          
##       Balanced Accuracy : 0.7639          
##                                           
##        'Positive' Class : No              
## 

(97+58)/200 # correctly predicts 77.5% of test cases.
## [1] 0.775

Now 77.5% of the test observations are correctly classified, so not only has the pruning process produced a more interpretable tree, but it has also improved the classification accuracy.

If we decrease the value of best, we obtain a smaller pruned tree, which here has lower classification accuracy:

prune.carseats=prune.misclass(tree.carseats,best=8)
plot(prune.carseats)
text(prune.carseats,pretty=0)

tree.pred=predict(prune.carseats,Carseats.test,type="class")
table(tree.pred,High.test)
##          High.test
## tree.pred No Yes
##       No  89  21
##       Yes 28  62
(89+62)/200 #75.5% accuracy
## [1] 0.755
confusionMatrix(tree.pred,High.test)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction No Yes
##        No  89  21
##        Yes 28  62
##                                           
##                Accuracy : 0.755           
##                  95% CI : (0.6894, 0.8129)
##     No Information Rate : 0.585           
##     P-Value [Acc > NIR] : 3.611e-07       
##                                           
##                   Kappa : 0.5015          
##                                           
##  Mcnemar's Test P-Value : 0.3914          
##                                           
##             Sensitivity : 0.7607          
##             Specificity : 0.7470          
##          Pos Pred Value : 0.8091          
##          Neg Pred Value : 0.6889          
##              Prevalence : 0.5850          
##          Detection Rate : 0.4450          
##    Detection Prevalence : 0.5500          
##       Balanced Accuracy : 0.7538          
##                                           
##        'Positive' Class : No              
## 

Now let’s try to fit a random forest model for this data set.

library(randomForest)

set.seed(1)
rf.carseats=randomForest(High~.-Sales,data=Carseats,subset=train,mtry=3,importance=TRUE)
yhat.rf = predict(rf.carseats,newdata=Carseats.test)

confusionMatrix(yhat.rf,High.test)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction  No Yes
##        No  110  24
##        Yes   7  59
##                                           
##                Accuracy : 0.845           
##                  95% CI : (0.7873, 0.8922)
##     No Information Rate : 0.585           
##     P-Value [Acc > NIR] : 1.939e-15       
##                                           
##                   Kappa : 0.671           
##                                           
##  Mcnemar's Test P-Value : 0.004057        
##                                           
##             Sensitivity : 0.9402          
##             Specificity : 0.7108          
##          Pos Pred Value : 0.8209          
##          Neg Pred Value : 0.8939          
##              Prevalence : 0.5850          
##          Detection Rate : 0.5500          
##    Detection Prevalence : 0.6700          
##       Balanced Accuracy : 0.8255          
##                                           
##        'Positive' Class : No              
## 

Here we can see a much improved accuracy rate (84.5%, versus 77.5% for the pruned tree) for the random forest model.
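Because the forest was fit with importance=TRUE, we can also ask which predictors drive its predictions. A quick sketch using functions from the randomForest package (the exact values depend on the fitted object):

```r
# Variable importance for the fitted random forest:
importance(rf.carseats)    # mean decrease in accuracy and in Gini index per predictor
varImpPlot(rf.carseats)    # dotchart of the two importance measures
```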